import torch
import torch.nn as nn
from nn_architectures import NetVAE, NetClassifier, NetDPEncoder, NetDPDecoder
from torch.distributions import Normal, Laplace, Categorical, OneHotCategorical
from torchvision.utils import make_grid
import torch.nn.functional as F
import numpy as np
from privacy_functions import mahalanobis_clip
import pandas as pd
from plotting_functions import plot_labels

class VAEBase(nn.Module):
    """Shared configuration plus encoder-side helpers for the VAE models.

    Subclasses are responsible for providing ``self.network`` (an object that
    exposes ``posterior_r_net``) and a ``model(r)`` decoder; neither is defined
    in this base class.
    """

    def __init__(self, opt):
        super().__init__()

        # Settings copied verbatim from the option object.
        for attr in ('device', 'batch_size', 'image_dim', 'rep_dim', 'md', 'latent_distn'):
            setattr(self, attr, getattr(opt, attr))

        # Prior p(r): zero mean (registered buffer) and a fixed std tensor.
        self.register_buffer('prior_mu', torch.zeros(1))
        self.prior_std = torch.tensor(opt.prior_std)
        if opt.latent_distn == 'Laplace':
            # A Laplace with scale b has std b * sqrt(2); convert std -> scale.
            self.prior_scale = opt.prior_std / torch.sqrt(torch.full((1,), 2.0))
        self.posterior_std = opt.posterior_std
        self.diagonal_x_std = opt.diagonal_x_std

        self.sigmoid = nn.Sigmoid()
        self.ce_loss = nn.CrossEntropyLoss()

        self.tabular = opt.tabular
        if self.tabular:
            # Count of continuous leading columns, and per-feature category counts
            # used as split sizes for the one-hot categorical block.
            self.n_cont = opt.n_continuous_features
            self.n_cats = opt.ncat_of_cat_features
        self.synthetic_generation = opt.synthetic_generation
        self.n_categories = opt.n_categories

    def posterior(self, x):
        """Draw a reparameterised sample r ~ q(r | x).

        When ``posterior_std`` is None the encoder output is split into a mean
        and a pre-softplus std (each ``rep_dim`` wide); otherwise the whole
        output is the mean and ``posterior_std`` is used as a fixed std.

        :param x: encoder input
        :return: ``(r, (r_mu, r_std))``
        """
        params = self.network.posterior_r_net(x)
        if self.posterior_std is not None:
            mu, std = params, self.posterior_std
        else:
            mu, raw_std = torch.split(params, self.rep_dim, dim=-1)
            std = F.softplus(raw_std) + 1e-5

        if self.md is not None:
            mu = mahalanobis_clip(mu, self.md)

        if self.latent_distn == 'Laplace':
            # Same std as the Gaussian case: scale = std / sqrt(2).
            sample = Laplace(mu, std / np.sqrt(2)).rsample()
        else:
            noise = torch.randn_like(mu).to(self.device)
            sample = noise * std + mu
        return sample, (mu, std)

    def get_data_representatation(self, x, clip=None, data_loader=True):
        """Return the encoder mean for ``x`` (no sampling).

        :param x: input; when ``data_loader`` is True it is a single example and
            a leading batch dimension is added
        :param clip: optional bound passed to ``mahalanobis_clip``
        :param data_loader: whether ``x`` lacks the batch dimension
        :return: the (optionally clipped) posterior mean
        """
        batch = x.unsqueeze(0) if data_loader else x
        params = self.network.posterior_r_net(batch.to(self.device))
        if self.posterior_std is None:
            mean, _ = torch.split(params, self.rep_dim, dim=-1)
        else:
            mean = params

        return mean if clip is None else mahalanobis_clip(mean, clip)

    def get_data_reconstruction(self, r_mu, privacy_inducing_std, clip=None):
        """Perturb a representation with noise and decode it.

        :param r_mu: representation to release
        :param privacy_inducing_std: std of the added Gaussian/Laplace noise
        :param clip: optional bound passed to ``mahalanobis_clip`` before noising
        :return: decoded means; for tabular data the continuous means are
            followed by a row-wise "is maximum" indicator for every categorical
            feature block
        """
        if clip is not None:
            r_mu = mahalanobis_clip(r_mu, clip)

        if self.latent_distn == 'Laplace':
            noisy_r = Laplace(r_mu, privacy_inducing_std / np.sqrt(2)).rsample()
        else:
            noisy_r = torch.randn_like(r_mu).to(self.device) * privacy_inducing_std + r_mu

        (x_mu, _), p_x_cat, _ = self.model(noisy_r)
        if not self.tabular:
            return x_mu

        indicator_blocks = []
        for block in torch.split(p_x_cat, self.n_cats, dim=1):
            row_max = block.max(dim=1, keepdim=True).values
            indicator_blocks.append((block == row_max).float())
        return torch.cat([x_mu, torch.cat(indicator_blocks, dim=1)], dim=1)


class VAE(VAEBase):
    """Concrete VAE: decoder head, negative-ELBO loss and sampling/logging helpers.

    ``self.network`` is assigned by subclasses (``NonDPModel``, ``DPEncoder``,
    ``DPDecoder``).
    """

    def _prior(self):
        """Return the prior p(r) selected by ``latent_distn``, placed on ``self.device``.

        ``prior_mu`` has shape (1,), so the returned distribution has batch shape (1,).
        """
        if self.latent_distn == 'Laplace':
            return Laplace(self.prior_mu.to(self.device), self.prior_scale.to(self.device))
        return Normal(self.prior_mu.to(self.device), self.prior_std.to(self.device))

    def _sample_prior(self, num):
        """Draw a ``(num, rep_dim)`` batch of latent vectors from the prior."""
        # Sampling shape (num, rep_dim) from a (1,)-batched prior yields
        # (num, rep_dim, 1); drop the trailing singleton dimension.
        return self._prior().sample(torch.Size([num, self.rep_dim]))[:, :, 0]

    def model(self, r, cat_softmax=True):
        """Decode latent ``r`` into the parameters of p(x | r) and, optionally, p(y | r).

        :param r: latent batch
        :param cat_softmax: for tabular data, return categorical blocks as
            per-feature softmax probabilities (True) or raw logits (False)
        :return: ``((x_mu, x_std), p_x_cat, p_y)``; ``p_x_cat`` is None for
            non-tabular data and ``p_y`` is None unless ``synthetic_generation``
        """
        if self.synthetic_generation:
            out = self.network.gen_x_net(r)
            # The last n_categories outputs are label logits.
            x_params, y_logits = torch.split(
                out, [out.size(1) - self.n_categories, self.n_categories], 1)
            p_y = F.softmax(y_logits, dim=1)
        else:
            x_params = self.network.gen_x_net(r)
            p_y = None

        if self.diagonal_x_std:
            # Two values per data dimension: mean and pre-softplus std.
            x_params = x_params.view((x_params.size(0), 2, *self.image_dim))
            x_mu = x_params[:, 0]
            x_std = F.softplus(x_params[:, 1]) + 1e-5
        else:
            x_mu = x_params.view((x_params.size(0), *self.image_dim))
            x_std = self.network.x_std

        if self.tabular:
            # Columns [n_cont:] hold the categorical logits, one block per feature.
            cat_out = x_mu[:, self.n_cont:]
            if cat_softmax:
                p_x_cat = torch.cat([F.softmax(logits, dim=1)
                                     for logits in torch.split(cat_out, self.n_cats, dim=1)],
                                    dim=1)
            else:
                p_x_cat = cat_out
            x_mu = x_mu[:, :self.n_cont]
            if self.diagonal_x_std:
                x_std = x_std[:, :self.n_cont]
        else:
            p_x_cat = None
        return (x_mu, x_std), p_x_cat, p_y

    def loss(self, x, y, encoder=None, decoder=None):
        """
        Negative ELBO for a batch, normalised per example.

        All terms are divided by the actual number of rows in ``x``, so a short
        final batch is scored consistently with the mean-reduced categorical
        cross-entropy.

        :param x: the input image/ table
        :param y: the label (set to None if no self.synthetic_generation)
        :param encoder: optionally use a different encoder network
        :param decoder: optionally use a different decoder network
        :return: ``(loss, -rec_x, -rec_y, kl)`` where ``loss = -rec_x - rec_y + kl``
        """
        n = x.size(0)
        if self.synthetic_generation:
            # reshape (not view) also handles non-contiguous inputs.
            posterior_in = torch.cat([x.reshape(n, -1), y.float()], dim=1)
        else:
            posterior_in = x

        enc = self if encoder is None else encoder
        dec = self if decoder is None else decoder
        r, (r_mu, r_std) = enc.posterior(posterior_in)
        (x_mu, x_std), p_x_cat, p_y = dec.model(r, cat_softmax=False)
        x_std = x_std.to(self.device)

        if self.tabular:
            x_cts = x[:, :self.n_cont]
            x_cat = x[:, self.n_cont:]
            rec_cts = torch.sum(Normal(x_mu, x_std).log_prob(x_cts)) / n
            # ce_loss (nn.CrossEntropyLoss, default mean reduction) already
            # averages over the batch; targets are the one-hot argmax indices.
            rec_cat = -sum(self.ce_loss(logits, torch.max(target, dim=1)[1].long())
                           for logits, target in zip(torch.split(p_x_cat, self.n_cats, dim=1),
                                                     torch.split(x_cat, self.n_cats, dim=1)))
            rec_x = rec_cat + rec_cts
        else:
            rec_x = torch.sum(Normal(x_mu, x_std).log_prob(x)) / n

        if self.synthetic_generation:
            rec_y = torch.sum(torch.log(torch.sum(y * p_y, dim=1))) / n
        else:
            rec_y = 0

        # Single-sample Monte-Carlo estimate of KL(q(r|x) || p(r)).
        p_r = torch.sum(self._prior().log_prob(r))
        if self.latent_distn == 'Laplace':
            q_r = torch.sum(Laplace(r_mu, r_std / np.sqrt(2)).log_prob(r))
        else:
            q_r = torch.sum(Normal(r_mu, r_std).log_prob(r))
        kl = (q_r - p_r) / n

        loss = -rec_x - rec_y + kl
        return loss, -rec_x, -rec_y, kl

    def reconstruct(self, x, y, epoch, writer, stage=1, encoder=None, decoder=None):
        """Encode/decode a batch and log the reconstructions to ``writer``.

        Images are logged as a grid of inputs followed by reconstructions;
        tabular categorical features and (if ``synthetic_generation``) labels
        are logged through ``plot_labels``.
        """
        x = x.to(self.device)
        if self.synthetic_generation:
            # Keep y on the same device as x so the concatenation succeeds.
            y = y.to(self.device)
            posterior_in = torch.cat([x.view(x.size(0), -1), y.float()], dim=1)
        else:
            posterior_in = x

        enc = self if encoder is None else encoder
        dec = self if decoder is None else decoder
        r, _ = enc.posterior(posterior_in)
        (p_x_mu, _), p_x_cat, p_y = dec.model(r)

        if self.tabular:
            true_blocks = torch.split(x[:, self.n_cont:], self.n_cats, dim=1)
            prob_blocks = torch.split(p_x_cat, self.n_cats, dim=1)
            for i, (true_block, cat_probs) in enumerate(zip(true_blocks, prob_blocks)):
                plot_labels(cat_probs,
                            torch.argmax(true_block, dim=1),
                            epoch, writer, 'Stage{}/reconstructions_catfeature{}'.format(stage, i))
            print("not currently plotting continuous feature reconstructions")
        else:
            x_with_recon = torch.cat((x, p_x_mu))
            writer.add_image(tag='Stage{}/reconstructions'.format(stage),
                             img_tensor=make_grid(self.sigmoid(x_with_recon)),
                             global_step=epoch)

        if self.synthetic_generation:
            plot_labels(p_y, torch.argmax(y, dim=1), epoch, writer, 'Stage{}/samples_label'.format(stage))

        print('Epoch: {}\tReconstructions generated'.format(epoch))

    def sample(self, epoch, writer, num=64, stage=1):
        """Decode ``num`` prior samples and log them to ``writer``."""
        z = self._sample_prior(num)
        (p_x_mu, _), p_x_cat, p_y = self.model(z)

        if self.tabular:
            for i, cat_probs in enumerate(torch.split(p_x_cat, self.n_cats, dim=1)):
                plot_labels(cat_probs, None, epoch, writer, 'Stage{}/samples_catfeature{}'.format(stage, i))
            print("not currently plotting continuous feature samples")
        else:
            writer.add_image('Stage{}/samples'.format(stage),
                             make_grid(self.sigmoid(p_x_mu), normalize=True), epoch)

        if self.synthetic_generation:
            plot_labels(p_y, None, epoch, writer, 'Stage{}/samples_label'.format(stage))

        print('Epoch: {}\tSamples generated'.format(epoch))

    def prior_sample(self, num):
        """Generate ``num`` synthetic examples by ancestral sampling.

        r is drawn from the same prior used in ``loss``. x is then sampled from
        p(x | r); categorical features are sampled as one-hot blocks.

        :return: ``(x, y)``; ``y`` holds sampled class indices when
            ``synthetic_generation`` is set, otherwise None
        """
        r = self._sample_prior(num)
        (x_mu, x_std), p_x_cat, p_y = self.model(r)
        x = Normal(x_mu, x_std.to(self.device)).rsample()
        if self.tabular:
            x_cat = [OneHotCategorical(probs=cat_probs).sample()
                     for cat_probs in torch.split(p_x_cat, self.n_cats, dim=1)]
            x = torch.cat([x] + x_cat, dim=1)

        # model() returns p_y=None without synthetic_generation; Categorical(None) would raise.
        y = Categorical(probs=p_y).sample() if self.synthetic_generation else None
        return x, y

    def sample_dataframe(self, n_samples):
        """Decode prior samples into a DataFrame (tabular data only).

        Columns ``0..n_cont-1`` hold the decoded continuous means; column
        ``n_cont + i`` holds the argmax category of categorical feature ``i``.

        :raises NotImplementedError: if the model is not tabular
        """
        if not self.tabular:
            raise NotImplementedError("only for tabular data")
        # No gradients are needed, and .numpy() refuses tensors that require grad.
        with torch.no_grad():
            z = self._sample_prior(n_samples)
            (p_x_mu, _), p_x_cat, _ = self.model(z)
        generated_df = pd.DataFrame(p_x_mu.cpu().numpy())

        for i, feature in enumerate(torch.split(p_x_cat, self.n_cats, dim=1)):
            generated_df[self.n_cont + i] = feature.argmax(dim=1).cpu().numpy()

        return generated_df


class NonDPModel(VAE):
    """VAE whose network is a ``NetVAE`` built from ``opt`` and moved to ``opt.device``."""

    def __init__(self, opt):
        super().__init__(opt)
        network = NetVAE(opt)
        self.network = network.to(opt.device)


class DPEncoder(VAE):
    """VAE whose network is a ``NetDPEncoder`` built from ``opt`` and moved to ``opt.device``."""

    def __init__(self, opt):
        super().__init__(opt)
        network = NetDPEncoder(opt)
        self.network = network.to(opt.device)


class DPDecoder(VAE):
    """VAE whose network is a ``NetDPDecoder`` built from ``opt`` and moved to ``opt.device``."""

    def __init__(self, opt):
        super().__init__(opt)
        network = NetDPDecoder(opt)
        self.network = network.to(opt.device)


class ClassifierBase(nn.Module):
    """Shared configuration and forward pass for the classifiers in this module."""

    def __init__(self, opt):
        super().__init__()
        # Settings copied verbatim from the option object.
        for attr in ('device', 'batch_size', 'image_dim', 'synthetic_generation', 'data_join_task'):
            setattr(self, attr, getattr(opt, attr))

        self.network = NetClassifier(opt).to(opt.device)

    def model(self, x):
        """Return the raw output of ``self.network.classifier_net`` on ``x``."""
        classifier = self.network.classifier_net
        return classifier(x)


class NaiveClassifier(ClassifierBase):
    """Classifier trained with plain cross-entropy on the (possibly noisy) labels."""

    def __init__(self, opt):
        super().__init__(opt)
        self.ce_loss = nn.CrossEntropyLoss()

    def _accuracy(self, inputs, label):
        """Fraction of rows whose argmax prediction on ``inputs`` equals ``label``.

        Divides by the number of evaluated rows rather than the configured
        ``batch_size``, so a short final batch is not under-reported.
        NOTE(review): assumes ``label`` has shape (batch, 1) like the keepdim
        prediction; a (batch,) label would broadcast to (batch, batch).
        """
        y_pred = torch.argmax(self.model(inputs), dim=1, keepdim=True)
        return (y_pred.float() == label.float()).sum().item() / y_pred.size(0)

    def loss(self, noisy_input, input, noisy_label, label, eval_accuracy=False):
        """Cross-entropy on the noisy pair, plus optional accuracies.

        :param noisy_input: inputs the loss is computed on (may be None)
        :param input: clean inputs, used only for ``accuracy_clean``
        :param noisy_label: labels for ``noisy_input``; squeezed on dim 1 before
            the loss, so presumably shaped (batch, 1)
        :param label: reference labels that both accuracies are measured against
        :param eval_accuracy: compute accuracies (otherwise both are 0.)
        :return: ``(loss, (accuracy_clean, accuracy_noisy))``; ``loss`` is ``0.``
            when either noisy argument is None
        """
        has_noisy = noisy_input is not None and noisy_label is not None  # True for synthetic test loss

        if has_noisy:
            y_probs = self.model(noisy_input)
            loss = self.ce_loss(y_probs, noisy_label.squeeze(1).long())
        else:
            loss = 0.

        accuracy_clean = 0.
        accuracy_noisy = 0.
        if eval_accuracy:
            # Clean accuracy is not reported for the data-join task.
            if not self.data_join_task:
                accuracy_clean = self._accuracy(input, label)
            if has_noisy:
                accuracy_noisy = self._accuracy(noisy_input, label)
        return loss, (accuracy_clean, accuracy_noisy)


class LabelNoiseClassifier(ClassifierBase):
    """Classifier trained through a symmetric label-noise channel.

    The observed label y_tilde equals the true label with probability
    ``1 - y_noise`` and is otherwise one of the other ``n_categories - 1``
    classes, each with equal probability.
    """

    def __init__(self, opt):
        super().__init__(opt)
        self.y_noise = opt.y_noise
        self.n_categories = opt.n_categories

    def log_prob_y_tilde(self, p_y_probs, y_tilde):
        """
        returns the log probability log p(y_tilde|x) = log sum_y{p(y_tilde|y) p(y|x)} evaluated at y_tilde.

        :param p_y_probs: class probabilities p(y|x), shape (batch, classes).
            NOTE(review): treated as probabilities here, while NaiveClassifier feeds
            the same model output to CrossEntropyLoss (which expects logits) -- confirm.
        :param y_tilde: observed class indices, shape (batch, 1) (required by scatter_)
        :return: tensor of shape (batch,) with log p(y_tilde | x)
        """
        n_classes = p_y_probs.shape[1]
        y_tilde_one_hot = torch.zeros_like(p_y_probs).scatter_(1, y_tilde.long(), 1)

        # Build the transition matrix directly on the target device.
        eye = torch.eye(n_classes, device=self.device)
        transition = eye * (1 - self.y_noise) + \
            (1 - eye) * (self.y_noise / (self.n_categories - 1))

        # (classes, classes) @ (batch, classes, 1) -> (batch, classes): p(y_tilde = k | x).
        p_y_tilde_all = torch.matmul(transition, p_y_probs.unsqueeze(-1)).squeeze(-1)
        p_y_tilde = torch.sum(p_y_tilde_all * y_tilde_one_hot, dim=1)

        return torch.log(p_y_tilde)

    def _accuracy(self, inputs, label):
        """Fraction of rows whose argmax prediction on ``inputs`` equals ``label``.

        Divides by the number of evaluated rows rather than the configured
        ``batch_size``, so a short final batch is not under-reported.
        """
        y_pred = torch.argmax(self.model(inputs), dim=1, keepdim=True)
        return (y_pred.float() == label.float()).sum().item() / y_pred.size(0)

    def loss(self, noisy_input, input, noisy_label, label, eval_accuracy=False):
        """Mean negative log-likelihood of the noisy labels, plus optional accuracies.

        :param noisy_input: inputs whose labels are noisy
        :param input: clean inputs, used only for ``accuracy_clean``
        :param noisy_label: observed (noisy) labels, shape (batch, 1)
        :param label: reference labels that both accuracies are measured against
        :param eval_accuracy: compute accuracies (otherwise both are 0.)
        :return: ``(loss, (accuracy_clean, accuracy_noisy))``
        """
        y_probs = self.model(noisy_input)
        log_prob_y_tilde = self.log_prob_y_tilde(y_probs, noisy_label)
        # Normalise by the actual batch size so short batches are not down-weighted.
        loss = -torch.sum(log_prob_y_tilde) / y_probs.size(0)

        if eval_accuracy:
            accuracy_clean = self._accuracy(input, label)
            accuracy_noisy = self._accuracy(noisy_input, label)
        else:
            accuracy_clean = 0.
            accuracy_noisy = 0.
        return loss, (accuracy_clean, accuracy_noisy)
